knitr::opts_chunk$set(fig.align="center")
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr)
library(ggplot2)
library(magrittr)
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)
theme_set(theme_light())
source('helper_functions.R')
In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s time to complete each of four tasks: 1. Find Extremum 2. Retrieve Value 3. Prediction 4. Exploration
Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict the time it takes the average user to complete each task. In addition, we would like to know if the choice of search algorithm and oracle has any meaningful impact on a user’s completion time for each of these four tasks.
# Load the completion-time records and convert the experimental design
# columns to factors in a single pass.
time_data <- read.csv("split_by_participant_groups/completion_time.csv") %>%
  mutate(across(c(dataset, oracle, search, task), as.factor))

# Label for each oracle/search pairing, joined with a single space.
time_data$condition <- paste(time_data$oracle, time_data$search)

# Task labels that the per-task loop further down iterates over.
task_list <- c(
  "1. Find Extremum",
  "2. Retrieve Value",
  "3. Prediction",
  "4. Exploration"
)

# Seed passed to the model fit and to the posterior-draw calls.
seed <- 12
The prior was derived from pilot studies. It describes the distribution of time (in seconds) needed to complete any given task. The lognormal family was selected to prevent our model from predicting completion times of less than zero seconds.
# Fit the lognormal mixed model of completion time.
#
# With `family = lognormal()` the linear predictor -- and therefore the
# Intercept -- is on the log-seconds scale. The pilot-derived values
# (360.48, 224.40) are stated in seconds, so using them directly as
# normal(360.48, 224.40) puts the Intercept prior on the wrong scale.
# They are converted here by moment-matching a lognormal distribution
# to that mean and standard deviation.
pilot_mean_seconds <- 360.48
pilot_sd_seconds <- 224.40
pilot_log_sd <- sqrt(log1p((pilot_sd_seconds / pilot_mean_seconds)^2))
pilot_log_mean <- log(pilot_mean_seconds) - pilot_log_sd^2 / 2

# NOTE: because `file = "models/time"` is set, brm() loads a previously saved
# fit when that file exists. Remove the saved fit (or refit) so that the
# corrected prior is actually used.
model <- brm(
  formula = bf(
    completion_time ~ oracle * search * dataset + task + participant_group +
      (1 | participant_id)
  ),
  # set_prior() takes the prior as a string, which lets the computed
  # log-scale values be inserted.
  prior = set_prior(
    sprintf("normal(%.4f, %.4f)", pilot_log_mean, pilot_log_sd),
    class = "Intercept"
  ),
  chains = 2,
  cores = 2,
  iter = 2500,
  warmup = 1000,
  data = time_data,
  file = "models/time",
  family = lognormal(),
  seed = seed
)
In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.
# Print the fitted model's summary table (estimates, credible intervals,
# Rhat, Bulk_ESS and Tail_ESS for each parameter).
summary(model)
## Family: lognormal
## Links: mu = identity; sigma = identity
## Formula: completion_time ~ oracle * search * dataset + task + participant_group + (1 | participant_id)
## Data: time_data (Number of observations: 264)
## Samples: 2 chains, each with iter = 2500; warmup = 1000; thin = 1;
## total post-warmup samples = 3000
##
## Group-Level Effects:
## ~participant_id (Number of levels: 66)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept) 0.20 0.04 0.12 0.28 1.01 917 1822
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat
## Intercept 5.31 0.12 5.08 5.55 1.00
## oracledziban 0.00 0.14 -0.27 0.28 1.00
## searchdfs 0.11 0.14 -0.17 0.39 1.00
## datasetmovies -0.13 0.14 -0.40 0.15 1.00
## task2.RetrieveValue 0.04 0.07 -0.09 0.18 1.00
## task3.Prediction 1.11 0.07 0.97 1.25 1.00
## task4.Exploration 1.21 0.07 1.07 1.34 1.00
## participant_groupstudent 0.05 0.07 -0.09 0.19 1.00
## oracledziban:searchdfs -0.11 0.20 -0.51 0.28 1.00
## oracledziban:datasetmovies 0.13 0.20 -0.27 0.51 1.00
## searchdfs:datasetmovies -0.03 0.20 -0.41 0.34 1.00
## oracledziban:searchdfs:datasetmovies -0.08 0.27 -0.60 0.46 1.00
## Bulk_ESS Tail_ESS
## Intercept 1706 1942
## oracledziban 1338 1761
## searchdfs 1501 1881
## datasetmovies 1475 1762
## task2.RetrieveValue 3904 2504
## task3.Prediction 4150 2641
## task4.Exploration 4166 2500
## participant_groupstudent 2460 2314
## oracledziban:searchdfs 1488 1944
## oracledziban:datasetmovies 1340 1779
## searchdfs:datasetmovies 1461 2152
## oracledziban:searchdfs:datasetmovies 1387 1942
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 0.40 0.02 0.37 0.45 1.00 2372 2109
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Trace plots help us check whether there is evidence of non-convergence for the model.
# Draw the default diagnostic plot for the fitted model; used here to
# inspect the chains' traces for signs of non-convergence.
plot(model)
In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters mean that our model has difficulty differentiating the effects of those parameters).
# Main-effect coefficients to compare pairwise. The names are passed
# with `fixed = TRUE`, i.e. matched exactly rather than as patterns.
pairs_pars <- c(
  "b_Intercept",
  "b_datasetmovies",
  "b_oracledziban",
  "b_searchdfs",
  "b_task2.RetrieveValue",
  "b_task3.Prediction",
  "b_task4.Exploration"
)
pairs(model, pars = pairs_pars, fixed = TRUE)
Using draws from the posterior, we can visualize parameter effects and average response. Be sure to apply an exponential transform to our log-transformed times to make them interpretable! The thicker, shorter line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.
# Fitted draws for every observed row. `re_formula = NA` leaves the
# group-level (per-participant) effects out of the predictions.
draw_data <- time_data %>%
  add_fitted_draws(model, seed = seed, re_formula = NA)

# For each task: plot the draws, save the figure, and export the
# 95% / 50% interval summaries per search/oracle/dataset cell.
for (task_name in task_list) {
  draw_data_sub <- subset(draw_data, task == task_name)
  # posterior_draws_plot() is not defined in this file (presumably it comes
  # from helper_functions.R); its result is handed to ggsave() as a plot.
  plot <- posterior_draws_plot(
    draw_data_sub,
    "dataset",
    FALSE,
    "Predicted Mean Completion Time (seconds)",
    "Oracle/Search Combination"
  )
  # Values are not auto-printed inside a `for` loop, so a bare `plot`
  # shows nothing; print explicitly.
  print(plot)

  # Build the file stem: strip everything up to the last ".", replace
  # spaces with "_", prefix with "time"
  # (e.g. "1. Find Extremum" -> "time_Find_Extremum").
  filename <- gsub("^.*\\.", "", task_name)
  filename <- gsub(" ", "_", filename)
  filename <- paste0("time", filename)
  ggsave(
    filename = paste0(filename, ".png"),
    plot = plot,
    path = "../plots/posterior_draws/time",
    width = 7,
    height = 7
  )

  fit_info <- draw_data_sub %>%
    group_by(search, oracle, dataset) %>%
    mean_qi(.value, .width = c(.95, .5))
  # Same auto-print caveat as for the plot above.
  print(fit_info)
  write.csv(
    fit_info,
    file.path("../plot_data/posterior_draws/time", paste0(filename, ".csv")),
    row.names = FALSE
  )
}
# Keep only the draws for the first two tasks and rebuild the task
# factor so the unused levels disappear.
focus_tasks <- c("1. Find Extremum", "2. Retrieve Value")
plot_data <- draw_data[draw_data$task %in% focus_tasks, ]
plot_data$task <- factor(plot_data$task)

# Display spellings for the oracle and search names, then a combined
# "Oracle + Search" condition label.
plot_data$oracle <- gsub(
  "dziban", "Dziban",
  gsub("compassql", "CompassQL", plot_data$oracle)
)
plot_data$search <- gsub(
  "dfs", "DFS",
  gsub("bfs", "BFS", plot_data$search)
)
plot_data$condition <- paste(plot_data$oracle, plot_data$search, sep = " + ")

# Same helper as the per-task plots, here with the third argument TRUE,
# plus larger y-axis text.
draw_plot <- posterior_draws_plot(
  plot_data,
  "dataset",
  TRUE,
  "Predicted Mean Completion Time (seconds)",
  "Oracle/Search Combination"
) +
  theme(axis.text.y = element_text(size = 12))
draw_plot
## # A tibble: 32 x 10
## # Groups: search, oracle, dataset [8]
## search oracle dataset task .value .lower .upper .width .point .interval
## <chr> <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 BFS Compas… birdstr… 1. Find… 229. 179. 286. 0.95 mean qi
## 2 BFS Compas… birdstr… 2. Retr… 239. 187. 301. 0.95 mean qi
## 3 BFS Compas… movies 1. Find… 202. 160. 252. 0.95 mean qi
## 4 BFS Compas… movies 2. Retr… 211. 165. 263. 0.95 mean qi
## 5 BFS Dziban birdstr… 1. Find… 229. 181. 286. 0.95 mean qi
## 6 BFS Dziban birdstr… 2. Retr… 239. 187. 299. 0.95 mean qi
## 7 BFS Dziban movies 1. Find… 228. 181. 283. 0.95 mean qi
## 8 BFS Dziban movies 2. Retr… 238. 189. 294. 0.95 mean qi
## 9 DFS Compas… birdstr… 1. Find… 255. 202. 318. 0.95 mean qi
## 10 DFS Compas… birdstr… 2. Retr… 266. 210. 331. 0.95 mean qi
## # … with 22 more rows
Next, we want to see if there is any significant difference in completion time between the two search algorithms (bfs and dfs) and the two oracles (dziban and compassql).
# Fitted draws for every observed row, without group-level effects
# (re_formula = NA), used by all of the contrast plots that follow.
predictive_data <- add_fitted_draws(
  time_data,
  model,
  seed = seed,
  re_formula = NA
)

# Contrast the two search algorithms per task. The helper is not defined
# in this file (presumably helper_functions.R); its result exposes `$plot`
# and `$intervals`.
search_differences <- expected_diff_in_mean_plot(
  predictive_data,
  "search",
  "Difference in Mean Completion Time (Seconds)",
  "Task",
  NULL
)
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "search_time_differences.png",
  plot = search_differences$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
search_differences$plot
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# Export the search-contrast interval table, then print it.
# `write.csv()` always writes comma-separated output and ignores `sep`
# (with a warning), so the argument is not passed.
write.csv(
  search_differences$intervals,
  "../plot_data/comparisons/time/search_time_differences.csv",
  row.names = FALSE
)
search_differences$intervals
## # A tibble: 8 x 8
## # Groups: search [1]
## search task difference .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs - dfs 1. Find Extremum -4.72 -52.5 41.7 0.95 mean qi
## 2 bfs - dfs 2. Retrieve Value -4.93 -55.3 43.8 0.95 mean qi
## 3 bfs - dfs 3. Prediction -14.5 -160. 127. 0.95 mean qi
## 4 bfs - dfs 4. Exploration -15.8 -178. 140. 0.95 mean qi
## 5 bfs - dfs 1. Find Extremum -4.72 -20.7 11.6 0.5 mean qi
## 6 bfs - dfs 2. Retrieve Value -4.93 -21.8 12.1 0.5 mean qi
## 7 bfs - dfs 3. Prediction -14.5 -63.7 35.0 0.5 mean qi
## 8 bfs - dfs 4. Exploration -15.8 -69.1 38.9 0.5 mean qi
Let’s do the above, but split it by datasets
# Search-algorithm contrast again, this time passing "dataset" as the
# final argument so the result is broken out per dataset.
search_differences_split_by_dataset <- expected_diff_in_mean_plot(
  predictive_data,
  "search",
  "Difference in Mean Completion Time (Seconds)",
  "Task",
  "dataset"
)
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "split_by_dataset_search_time_differences.png",
  plot = search_differences_split_by_dataset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
search_differences_split_by_dataset$plot
Check intervals
# Export the per-dataset search-contrast intervals, then print them.
write.csv(
  search_differences_split_by_dataset$intervals,
  "../plot_data/comparisons/time/search_time_differences_split_by_dataset.csv",
  row.names = FALSE
)
search_differences_split_by_dataset$intervals
## # A tibble: 16 x 9
## # Groups: search, dataset [2]
## search dataset task difference .lower .upper .width .point .interval
## <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs - d… birdstr… 1. Find E… -13.1 -60.4 34.9 0.95 mean qi
## 2 bfs - d… birdstr… 2. Retrie… -13.7 -62.6 36.3 0.95 mean qi
## 3 bfs - d… birdstr… 3. Predic… -40.0 -182. 103. 0.95 mean qi
## 4 bfs - d… birdstr… 4. Explor… -43.8 -199. 113. 0.95 mean qi
## 5 bfs - d… movies 1. Find E… 3.65 -37.8 45.3 0.95 mean qi
## 6 bfs - d… movies 2. Retrie… 3.84 -39.6 47.6 0.95 mean qi
## 7 bfs - d… movies 3. Predic… 11.0 -118. 138. 0.95 mean qi
## 8 bfs - d… movies 4. Explor… 12.2 -130. 155. 0.95 mean qi
## 9 bfs - d… birdstr… 1. Find E… -13.1 -29.4 2.58 0.5 mean qi
## 10 bfs - d… birdstr… 2. Retrie… -13.7 -30.7 2.64 0.5 mean qi
## 11 bfs - d… birdstr… 3. Predic… -40.0 -89.5 7.77 0.5 mean qi
## 12 bfs - d… birdstr… 4. Explor… -43.8 -98.3 8.71 0.5 mean qi
## 13 bfs - d… movies 1. Find E… 3.65 -11.0 18.2 0.5 mean qi
## 14 bfs - d… movies 2. Retrie… 3.84 -11.4 19.0 0.5 mean qi
## 15 bfs - d… movies 3. Predic… 11.0 -33.3 55.4 0.5 mean qi
## 16 bfs - d… movies 4. Explor… 12.2 -36.5 60.1 0.5 mean qi
# Contrast the two oracles per task, using the same helper and the same
# draws as the search contrast.
oracle_differences <- expected_diff_in_mean_plot(
  predictive_data,
  "oracle",
  "Difference in Mean Completion Time (Seconds)",
  "Task",
  NULL
)
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "oracle_time_differences.png",
  plot = oracle_differences$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
oracle_differences$plot
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# Export the oracle-contrast interval table, then print it.
# `write.csv()` always writes comma-separated output and ignores `sep`
# (with a warning), so the argument is not passed.
write.csv(
  oracle_differences$intervals,
  "../plot_data/comparisons/time/oracle_time_differences.csv",
  row.names = FALSE
)
oracle_differences$intervals
## # A tibble: 8 x 8
## # Groups: oracle [1]
## oracle task difference .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban - compa… 1. Find Extr… -2.42 -52.5 45.1 0.95 mean qi
## 2 dziban - compa… 2. Retrieve … -2.51 -54.3 46.8 0.95 mean qi
## 3 dziban - compa… 3. Prediction -7.29 -161. 137. 0.95 mean qi
## 4 dziban - compa… 4. Explorati… -8.03 -177. 150. 0.95 mean qi
## 5 dziban - compa… 1. Find Extr… -2.42 -19.1 14.7 0.5 mean qi
## 6 dziban - compa… 2. Retrieve … -2.51 -20.0 15.4 0.5 mean qi
## 7 dziban - compa… 3. Prediction -7.29 -58.0 44.6 0.5 mean qi
## 8 dziban - compa… 4. Explorati… -8.03 -64.4 49.0 0.5 mean qi
Let’s do the above, but split it by datasets
# Oracle contrast broken out per dataset ("dataset" passed as the final
# argument to the helper).
oracle_differences_subset_split_by_dataset <- expected_diff_in_mean_plot(
  predictive_data,
  "oracle",
  "Difference in Mean Completion Time (Seconds)",
  "Task",
  "dataset"
)
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "split_by_dataset_oracle_time_differences.png",
  plot = oracle_differences_subset_split_by_dataset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
oracle_differences_subset_split_by_dataset$plot
Check intervals
# Export the per-dataset oracle-contrast intervals, then print them.
write.csv(
  oracle_differences_subset_split_by_dataset$intervals,
  "../plot_data/comparisons/time/oracle_time_differences_split_by_dataset.csv",
  row.names = FALSE
)
oracle_differences_subset_split_by_dataset$intervals
## # A tibble: 16 x 9
## # Groups: oracle, dataset [2]
## oracle dataset task difference .lower .upper .width .point .interval
## <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban -… birdstr… 1. Find… -12.7 -59.7 34.5 0.95 mean qi
## 2 dziban -… birdstr… 2. Retr… -13.2 -62.5 36.0 0.95 mean qi
## 3 dziban -… birdstr… 3. Pred… -38.3 -181. 106. 0.95 mean qi
## 4 dziban -… birdstr… 4. Expl… -42.1 -199. 116. 0.95 mean qi
## 5 dziban -… movies 1. Find… 7.82 -35.2 49.7 0.95 mean qi
## 6 dziban -… movies 2. Retr… 8.15 -37.3 51.9 0.95 mean qi
## 7 dziban -… movies 3. Pred… 23.8 -107. 150. 0.95 mean qi
## 8 dziban -… movies 4. Expl… 26.1 -118. 165. 0.95 mean qi
## 9 dziban -… birdstr… 1. Find… -12.7 -28.6 3.23 0.5 mean qi
## 10 dziban -… birdstr… 2. Retr… -13.2 -29.4 3.42 0.5 mean qi
## 11 dziban -… birdstr… 3. Pred… -38.3 -85.9 9.75 0.5 mean qi
## 12 dziban -… birdstr… 4. Expl… -42.1 -94.3 10.5 0.5 mean qi
## 13 dziban -… movies 1. Find… 7.82 -5.83 22.2 0.5 mean qi
## 14 dziban -… movies 2. Retr… 8.15 -6.05 23.0 0.5 mean qi
## 15 dziban -… movies 3. Pred… 23.8 -17.8 66.7 0.5 mean qi
## 16 dziban -… movies 4. Expl… 26.1 -19.6 72.9 0.5 mean qi
# Contrast the participant groups per task, using the same helper and
# the same draws as the search and oracle contrasts.
participant_group_differences <- expected_diff_in_mean_plot(
  predictive_data,
  "participant_group",
  "Difference in Mean Completion Time (Seconds)",
  "Task",
  NULL
)
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "participant_group_time_differences.png",
  plot = participant_group_differences$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
participant_group_differences$plot
Let’s do the above, but split it by datasets
# Participant-group contrast broken out per dataset ("dataset" passed as
# the final argument to the helper).
participant_group_differences_split_by_dataset <- expected_diff_in_mean_plot(
  predictive_data,
  "participant_group",
  "Difference in Mean Completion Time (Seconds)",
  "Task",
  "dataset"
)
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "split_by_dataset_participant_group_time_differences.png",
  plot = participant_group_differences_split_by_dataset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
participant_group_differences_split_by_dataset$plot
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# Export the participant-group interval table, then print it.
# `write.csv()` always writes comma-separated output and ignores `sep`
# (with a warning), so the argument is not passed.
write.csv(
  participant_group_differences$intervals,
  "../plot_data/comparisons/time/participant_group_time_differences.csv",
  row.names = FALSE
)
participant_group_differences$intervals
## # A tibble: 8 x 8
## # Groups: participant_group [1]
## participant_group task difference .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 student - profess… 1. Find … 10.4 -20.9 42.1 0.95 mean qi
## 2 student - profess… 2. Retri… 10.8 -22.1 43.8 0.95 mean qi
## 3 student - profess… 3. Predi… 31.5 -65.0 127. 0.95 mean qi
## 4 student - profess… 4. Explo… 34.8 -69.6 142. 0.95 mean qi
## 5 student - profess… 1. Find … 10.4 -0.444 21.2 0.5 mean qi
## 6 student - profess… 2. Retri… 10.8 -0.472 22.1 0.5 mean qi
## 7 student - profess… 3. Predi… 31.5 -1.38 64.6 0.5 mean qi
## 8 student - profess… 4. Explo… 34.8 -1.50 71.3 0.5 mean qi
Plot all of the posterior draws on one plot
# Half-eye densities of the fitted draws for every task, filled by
# search algorithm and faceted by dataset, with 95% and 50% intervals.
plot <- draw_data %>%
  ggplot(aes(x = .value, y = task, fill = search)) +
  # Transparency is a fixed setting, so it belongs outside aes(). Inside
  # aes() the constant is treated as mapped data, which adds a spurious
  # "alpha" legend and passes the value through the alpha scale.
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  labs(x = "Time (Seconds)", y = "Task") +
  facet_grid(. ~ dataset)
plot
## Saving 7 x 5 in image
#Code for additional plots (mostly subsets)
# Restrict the draws to the first two tasks and rebuild the task factor
# so that the unused levels are dropped.
predictive_data_subset <- predictive_data[
  predictive_data$task %in% c("1. Find Extremum", "2. Retrieve Value"),
]
predictive_data_subset$task <- factor(predictive_data_subset$task)

# Search-algorithm contrast on the two-task subset.
search_differences_subset <- expected_diff_in_mean_plot(
  predictive_data_subset,
  "search",
  "Expected Time Difference (Seconds)",
  "Task",
  NULL
)
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "search_time_differences_subset.png",
  plot = search_differences_subset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
search_differences_subset$plot
# Search-algorithm contrast on the two-task subset, broken out per
# dataset ("dataset" passed as the final argument to the helper).
diff_in_search_prediction_split_by_dataset_subset <- expected_diff_in_mean_plot(
  predictive_data_subset,
  "search",
  "Expected Time Difference (Seconds)",
  "Task",
  "dataset"
)
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "split_by_dataset_search_time_differences_subset.png",
  plot = diff_in_search_prediction_split_by_dataset_subset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
diff_in_search_prediction_split_by_dataset_subset$plot
# Oracle contrast on the two-task subset.
oracle_differences_subset <- expected_diff_in_mean_plot(
  predictive_data_subset,
  "oracle",
  "Expected Time Difference (Seconds)",
  "Task",
  NULL
)
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "oracle_time_differences_subset.png",
  plot = oracle_differences_subset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
oracle_differences_subset$plot
# Oracle contrast on the two-task subset, broken out per dataset
# ("dataset" passed as the final argument to the helper).
oracle_differences_subset_split_by_dataset_subset <- expected_diff_in_mean_plot(
  predictive_data_subset,
  "oracle",
  "Expected Time Difference (Seconds)",
  "Task",
  "dataset"
)
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
ggsave(
  filename = "split_by_dataset_oracle_time_differences_subset.png",
  plot = oracle_differences_subset_split_by_dataset_subset$plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
oracle_differences_subset_split_by_dataset_subset$plot